﻿using Dapper;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Linq;
using System.Threading.Tasks;

namespace General.Entities.Dapper
{
    /// <summary>
    /// Dapper-based data-access helper.
    /// Instance members run against the database identified by <see cref="dbConnKey"/> and
    /// <see cref="dbType"/>, enlisting in the transaction started by <see cref="BeginTrans"/> when one
    /// is active. Static members run a single command on a connection obtained from
    /// <c>DbConnectionFactory.GetConnection</c> and release it afterwards.
    /// </summary>
    public class DbRepository
    {
        /// <summary>
        /// Connection cached by <see cref="Connection"/>; <see cref="BeginTrans"/> starts transactions on it.
        /// Non-transactional instance calls do not use (and therefore never dispose) it, so it stays
        /// reusable for later transactions.
        /// </summary>
        private IDbConnection _connection;

        /// <summary>Connection key passed to <c>DbConnectionFactory.GetConnection</c>.</summary>
        public string dbConnKey { get; set; }

        /// <summary>Database type passed to <c>DbConnectionFactory.GetConnection</c>.</summary>
        public DbType dbType { get; set; }

        /// <summary>
        /// Creates a repository for the given connection key and database type.
        /// </summary>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="dbType">Database type; defaults to <c>DbType.SqlServer</c>.</param>
        public DbRepository(string dbConnKey, DbType dbType = DbType.SqlServer)
        {
            this.dbConnKey = dbConnKey;
            this.dbType = dbType;
        }

        /// <summary>
        /// Creates a repository that uses the "SqlServer" connection key and <c>DbType.SqlServer</c>,
        /// consistent with the defaults of the other constructor.
        /// </summary>
        public DbRepository()
            : this("SqlServer")
        {
        }

        /// <summary>
        /// Re-targets this repository at another connection key / database type.
        /// </summary>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="dbType">Database type; defaults to <c>DbType.SqlServer</c>.</param>
        public void Init(string dbConnKey, DbType dbType = DbType.SqlServer)
        {
            // NOTE(review): a connection already cached by Connection is not replaced here, so a
            // later BeginTrans still uses the previous key/type — confirm whether that is intended.
            this.dbConnKey = dbConnKey;
            this.dbType = dbType;
        }

        #region Properties

        /// <summary>
        /// Gets the cached database connection, creating it on first access from
        /// <see cref="dbConnKey"/> and <see cref="dbType"/>.
        /// </summary>
        public IDbConnection Connection
        {
            get
            {
                if (_connection == null)
                {
                    _connection = DbConnectionFactory.GetConnection(dbConnKey, dbType);
                }
                return _connection;
            }
        }

        /// <summary>
        /// The active transaction, or null when instance methods should run without one.
        /// </summary>
        public IDbTransaction dbTransaction { get; set; }

        #endregion

        #region Transactions

        /// <summary>
        /// Starts a transaction on <see cref="Connection"/>, opening the connection first if it is closed.
        /// </summary>
        /// <returns>This repository, for call chaining.</returns>
        public DbRepository BeginTrans()
        {
            IDbConnection connection = Connection;

            // BeginTransaction needs an open connection; Dapper only auto-opens connections
            // for its own query/execute calls.
            if (connection.State == ConnectionState.Closed)
            {
                connection.Open();
            }

            dbTransaction = connection.BeginTransaction();
            return this;
        }

        /// <summary>
        /// Commits the active transaction (if any), then releases it and closes its connection.
        /// </summary>
        /// <returns>Always 1 when no exception is thrown.</returns>
        public int Commit()
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction == null)
            {
                // Nothing to commit; just make sure the cached connection is not left open.
                Close();
                return 1;
            }

            // Captured up front: SqlTransaction.Connection returns null once the transaction completes.
            IDbConnection transactionConnection = transaction.Connection;
            try
            {
                transaction.Commit();
                return 1;
            }
            finally
            {
                // Runs on success and on failure so the finished transaction is never reused.
                // Calling Rollback afterwards is a safe no-op.
                EndTransaction(transaction, transactionConnection);
            }
        }

        /// <summary>
        /// Rolls back the active transaction (if any), then releases it and closes its connection.
        /// When no transaction is active this only closes the cached connection.
        /// </summary>
        public void Rollback()
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction == null)
            {
                Close();
                return;
            }

            IDbConnection transactionConnection = transaction.Connection;
            try
            {
                transaction.Rollback();
            }
            finally
            {
                EndTransaction(transaction, transactionConnection);
            }
        }

        /// <summary>
        /// Closes the active transaction's connection and the cached connection if they are open.
        /// The cached connection is closed, not disposed, so it can be reopened by <see cref="BeginTrans"/>.
        /// </summary>
        public void Close()
        {
            CloseIfOpen(dbTransaction?.Connection);
            CloseIfOpen(_connection);
        }

        /// <summary>
        /// Clears <see cref="dbTransaction"/>, disposes the finished transaction and closes its connection.
        /// </summary>
        private void EndTransaction(IDbTransaction transaction, IDbConnection transactionConnection)
        {
            dbTransaction = null;
            transaction.Dispose();
            CloseIfOpen(transactionConnection);
            Close();
        }

        /// <summary>
        /// Closes <paramref name="connection"/> when it is non-null and not already closed.
        /// </summary>
        private static void CloseIfOpen(IDbConnection connection)
        {
            if (connection != null && connection.State != ConnectionState.Closed)
            {
                connection.Close();
            }
        }

        #endregion

        #region Instance methods

        #region Queries

        /// <summary>
        /// Runs a query and returns the first row mapped to <typeparamref name="T"/>, or default if there are no rows.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="buffered">Not used by this method; kept for signature compatibility.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The first mapped row, or default.</returns>
        public T QueryFirst<T>(string sql, object param = null, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return transaction.Connection.QueryFirstOrDefault<T>(sql, param, transaction, commandTimeout, commandType);
            }

            // Without a transaction, use a short-lived connection so the cached one is never disposed.
            return DbRepository.QueryFirst<T>(sql, dbConnKey, param, dbType, buffered, commandTimeout, commandType);
        }

        /// <summary>
        /// Asynchronously runs a query and returns the first row mapped to <typeparamref name="T"/>, or default.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="buffered">Not used by this method; kept for signature compatibility.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The first mapped row, or default.</returns>
        public async Task<T> QueryFirstAsync<T>(string sql, object param = null, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return await transaction.Connection.QueryFirstOrDefaultAsync<T>(sql, param, transaction, commandTimeout, commandType);
            }

            return await DbRepository.QueryFirstAsync<T>(sql, dbConnKey, param, dbType, buffered, commandTimeout, commandType);
        }

        /// <summary>
        /// Runs a query and maps every row to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="buffered">
        /// When false the rows are streamed. Outside a transaction the connection then stays open
        /// until enumeration finishes or the enumerator is disposed.
        /// </param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The mapped rows.</returns>
        public IEnumerable<T> Query<T>(string sql, object param = null, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return transaction.Connection.Query<T>(sql, param, transaction, buffered, commandTimeout, commandType);
            }

            return DbRepository.Query<T>(sql, dbConnKey, param, dbType, buffered, commandTimeout, commandType);
        }

        /// <summary>
        /// Asynchronously runs a query and maps every row to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="buffered">Not forwarded to Dapper by this method; kept for signature compatibility.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The mapped rows.</returns>
        public async Task<IEnumerable<T>> QueryAsync<T>(string sql, object param = null, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return await transaction.Connection.QueryAsync<T>(sql, param, transaction, commandTimeout, commandType);
            }

            return await DbRepository.QueryAsync<T>(sql, dbConnKey, param, dbType, buffered, commandTimeout, commandType);
        }

        /// <summary>
        /// Runs a query and returns the raw <see cref="IDataReader"/>.
        /// Callers must dispose the reader. Outside a transaction, closing the reader also closes its connection.
        /// </summary>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>An open data reader.</returns>
        public IDataReader ExecuteReader(string sql, object param = null, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return transaction.Connection.ExecuteReader(sql, param, transaction, commandTimeout, commandType);
            }

            return DbRepository.ExecuteReader(sql, dbConnKey, param, dbType, commandTimeout, commandType);
        }

        /// <summary>
        /// Runs a query and returns the first column of the first row converted to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The scalar value.</returns>
        public T ExecuteScalar<T>(string sql, object param = null, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return transaction.Connection.ExecuteScalar<T>(sql, param, transaction, commandTimeout, commandType);
            }

            return DbRepository.ExecuteScalar<T>(sql, dbConnKey, param, dbType, commandTimeout, commandType);
        }
        #endregion

        /// <summary>
        /// Executes an INSERT/UPDATE/DELETE statement.
        /// </summary>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Statement parameters.</param>
        /// <param name="commandTimeout">Command timeout in seconds (honoured inside transactions too).</param>
        /// <param name="commandType">Command type (honoured inside transactions too).</param>
        /// <returns>Number of affected rows.</returns>
        public int ExecuteSql(string sql, object param = null, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return transaction.Connection.Execute(sql, param, transaction, commandTimeout, commandType);
            }

            return DbRepository.Execute(sql, dbConnKey, param, dbType, commandTimeout, commandType);
        }

        /// <summary>
        /// Asynchronously executes an INSERT/UPDATE/DELETE statement.
        /// </summary>
        /// <param name="sql">SQL text.</param>
        /// <param name="param">Statement parameters.</param>
        /// <param name="commandTimeout">Command timeout in seconds (honoured inside transactions too).</param>
        /// <param name="commandType">Command type (honoured inside transactions too).</param>
        /// <returns>Number of affected rows, with or without an active transaction.</returns>
        public async Task<int> ExecuteSqlAsync(string sql, object param = null, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbTransaction transaction = dbTransaction;
            if (transaction != null)
            {
                return await transaction.Connection.ExecuteAsync(sql, param, transaction, commandTimeout, commandType);
            }

            return await DbRepository.ExecuteAsync(sql, dbConnKey, param, dbType, commandTimeout, commandType);
        }


        #endregion

        #region Static methods

        #region Queries
        /// <summary>
        /// Runs a query and returns the first row mapped to <typeparamref name="T"/>, or default if there are no rows.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="buffered">Not used by this method; kept for signature compatibility.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The first mapped row, or default.</returns>
        public static T QueryFirst<T>(string sql, string dbConnKey, object param = null, DbType dbType = DbType.SqlServer, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType))
            {
                return dbConn.QueryFirstOrDefault<T>(sql, param, null, commandTimeout, commandType);
            }
        }

        /// <summary>
        /// Asynchronously runs a query and returns the first row mapped to <typeparamref name="T"/>, or default.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="buffered">Not used by this method; kept for signature compatibility.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The first mapped row, or default.</returns>
        public static async Task<T> QueryFirstAsync<T>(string sql, string dbConnKey, object param = null, DbType dbType = DbType.SqlServer, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType))
            {
                return await dbConn.QueryFirstOrDefaultAsync<T>(sql, param, null, commandTimeout, commandType);
            }
        }

        /// <summary>
        /// Runs a query and maps every row to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="buffered">
        /// When true all rows are read before returning. When false the rows are streamed and the
        /// connection stays open until enumeration finishes or the enumerator is disposed.
        /// </param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The mapped rows.</returns>
        public static IEnumerable<T> Query<T>(string sql, string dbConnKey, object param = null, DbType dbType = DbType.SqlServer, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            if (!buffered)
            {
                // A lazily-read sequence must own its connection; disposing it here would make
                // enumeration fail later.
                return QueryUnbuffered<T>(sql, dbConnKey, param, dbType, commandTimeout, commandType);
            }

            using (var dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType))
            {
                return dbConn.Query<T>(sql, param, null, true, commandTimeout, commandType);
            }
        }

        /// <summary>
        /// Streams query results, keeping the connection alive for the lifetime of the enumeration.
        /// The connection is created on first enumeration and disposed when the enumerator is disposed.
        /// </summary>
        private static IEnumerable<T> QueryUnbuffered<T>(string sql, string dbConnKey, object param, DbType dbType, int? commandTimeout, CommandType? commandType)
        {
            using (IDbConnection dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType))
            {
                foreach (T row in dbConn.Query<T>(sql, param, null, false, commandTimeout, commandType))
                {
                    yield return row;
                }
            }
        }

        /// <summary>
        /// Asynchronously runs a query and maps every row to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="buffered">Not forwarded to Dapper by this method; kept for signature compatibility.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The mapped rows.</returns>
        public static async Task<IEnumerable<T>> QueryAsync<T>(string sql, string dbConnKey, object param = null, DbType dbType = DbType.SqlServer, bool buffered = true, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType))
            {
                return await dbConn.QueryAsync<T>(sql, param, null, commandTimeout, commandType);
            }
        }

        /// <summary>
        /// Runs a query and returns the raw <see cref="IDataReader"/>.
        /// Callers must dispose the reader; closing it also closes the underlying connection.
        /// </summary>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>An open data reader that owns its connection.</returns>
        public static IDataReader ExecuteReader(string sql, string dbConnKey, object param = null, DbType dbType = DbType.SqlServer, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            IDbConnection dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType);
            try
            {
                // The reader outlives this method, so the connection must not be disposed here.
                // Relies on Dapper's behaviour: when it is handed a closed connection, it opens the
                // connection itself and executes with CommandBehavior.CloseConnection. Closing first
                // therefore ties the connection's lifetime to the returned reader.
                if (dbConn.State != ConnectionState.Closed)
                {
                    dbConn.Close();
                }
                return dbConn.ExecuteReader(sql, param, null, commandTimeout, commandType);
            }
            catch
            {
                dbConn.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Runs a query and returns the first column of the first row converted to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Result type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbConnKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Query parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The scalar value.</returns>
        public static T ExecuteScalar<T>(string sql, string dbConnKey, object param = null, DbType dbType = DbType.SqlServer, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbConnKey, dbType))
            {
                return dbConn.ExecuteScalar<T>(sql, param, null, commandTimeout, commandType);
            }
        }

        #endregion

        #region Insert / update / delete

        /// <summary>
        /// Executes an INSERT/UPDATE/DELETE statement on a short-lived connection.
        /// </summary>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbkey">Connection key passed to the connection factory.</param>
        /// <param name="param">Statement parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>Number of affected rows.</returns>
        public static int Execute(string sql, string dbkey, object param = null, DbType dbType = DbType.SqlServer, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbkey, dbType))
            {
                return dbConn.Execute(sql, param, null, commandTimeout, commandType);
            }
        }

        /// <summary>
        /// Asynchronously executes an INSERT/UPDATE/DELETE statement on a short-lived connection.
        /// </summary>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbkey">Connection key passed to the connection factory.</param>
        /// <param name="param">Statement parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>Number of affected rows.</returns>
        public static async Task<int> ExecuteAsync(string sql, string dbkey, object param = null, DbType dbType = DbType.SqlServer, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbkey, dbType))
            {
                return await dbConn.ExecuteAsync(sql, param, null, commandTimeout, commandType);
            }
        }


        /// <summary>
        /// Executes SQL generated by the <c>DynamicQuery.GetInsertQuery*</c> methods and returns the
        /// first column of the first row (the generated identity value), or default if there is no row.
        /// </summary>
        /// <typeparam name="T">Identity type.</typeparam>
        /// <param name="sql">SQL text.</param>
        /// <param name="dbKey">Connection key passed to the connection factory.</param>
        /// <param name="param">Statement parameters.</param>
        /// <param name="dbType">Database type.</param>
        /// <param name="commandTimeout">Command timeout in seconds.</param>
        /// <param name="commandType">Command type.</param>
        /// <returns>The value returned by the statement, or default.</returns>
        public static T ExecuteInsertSql<T>(string sql, string dbKey, object param = null, DbType dbType = DbType.SqlServer, int? commandTimeout = default(int?), CommandType? commandType = default(CommandType?))
        {
            using (var dbConn = DbConnectionFactory.GetConnection(dbKey, dbType))
            {
                return dbConn.QueryFirstOrDefault<T>(sql, param, null, commandTimeout, commandType);
            }
        }

        #endregion

        #endregion

    }
}
